{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Seeding CaImAn with a manual mask and using a structural channel for motion correction\n",
    "\n",
    "This notebook aims to show how to use CaImAn for the following two use cases:\n",
    "\n",
    "- How to perform motion correction on a structural (red) channel and apply the inferred shifts to the functional channel (green).\n",
    "- How to use manual binary masks to seed the CNMF algorithm, including an approach for automatic segmentation of a structural channel. \n",
    "\n",
    "This notebook will only focus on these points while building upon CaImAn's demo pipeline. For a general demonstration of CaImAn's pipeline, please refer to the [general pipeline demo](./demo_pipeline.ipynb) notebook.\n",
    "\n",
    "Note that using information from a structural channel does not guarantee improved performance in terms of motion correction and cell identification. Structural channels might give poorer motion correction results depending on the SNR, and might express a different set of neurons compared to the functional channel.\n",
    "\n",
    "Dataset courtesy of Tolias lab (Baylor College of Medicine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bokeh.plotting as bpl\n",
    "from IPython import get_ipython\n",
    "import logging\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import cv2\n",
    "\n",
    "from glob import glob\n",
    "\n",
    "try:\n",
    "    cv2.setNumThreads(0)\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "try:\n",
    "    if __IPYTHON__:\n",
    "        get_ipython().run_line_magic('load_ext', 'autoreload')\n",
    "        get_ipython().run_line_magic('autoreload', '2')\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "logfile = None # Replace with a path if you want to log to a file\n",
    "logger = logging.getLogger('caiman')\n",
    "# Set to logging.INFO for more verbose (potentially very verbose) output\n",
    "logger.setLevel(logging.WARNING)\n",
    "logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')\n",
    "if logfile is not None:\n",
    "    handler = logging.FileHandler(logfile)\n",
    "else:\n",
    "    handler = logging.StreamHandler()\n",
    "handler.setFormatter(logfmt)\n",
    "logger.addHandler(handler)\n",
    "\n",
    "import caiman as cm\n",
    "from caiman.motion_correction import MotionCorrect\n",
    "from caiman.source_extraction.cnmf import cnmf as cnmf\n",
    "from caiman.source_extraction.cnmf import params as params\n",
    "from caiman.utils.utils import download_demo\n",
    "from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour\n",
    "import holoviews as hv\n",
    "\n",
    "bpl.output_notebook()\n",
    "hv.notebook_extension('bokeh')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select file(s) from structural channel to be motion corrected\n",
    "\n",
    "The `download_demo` function will download the specific file for you and return the complete path to the file which will be stored in your `caiman_data` directory. If you adapt this demo for your data make sure to pass the complete path to your file(s). Remember to pass the `fnames` variable as a list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fname_red = [download_demo('gmc_960_30mw_00001_red.tif')]   # filename to be processed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup motion correction object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fr = 30             # imaging rate in frames per second\n",
    "dxy = (1., 1.)      # spatial resolution in x and y in (um per pixel)\n",
    "                    # note the lower than usual spatial resolution here\n",
    "max_shift_um = (12., 12.)       # maximum shift in um\n",
    "patch_motion_um = (100., 100.)  # patch size for non-rigid correction in um\n",
    "\n",
    "pw_rigid = True       # flag to select rigid vs pw_rigid motion correction\n",
    "max_shifts = [int(a/b) for a, b in zip(max_shift_um, dxy)]\n",
    "strides = tuple([int(a/b) for a, b in zip(patch_motion_um, dxy)])\n",
    "overlaps = (24, 24)\n",
    "max_deviation_rigid = 3\n",
    "\n",
    "mc_dict = {\n",
    "   'fnames': fname_red,\n",
    "   'fr': fr,\n",
    "   'dxy': dxy,\n",
    "   'pw_rigid': pw_rigid,\n",
    "   'max_shifts': max_shifts,\n",
    "   'strides': strides,\n",
    "   'overlaps': overlaps,\n",
    "   'max_deviation_rigid': max_deviation_rigid,\n",
    "   'border_nan': 'copy'\n",
    "}\n",
    "\n",
    "opts = params.CNMFParams(params_dict=mc_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start a cluster and perform motion correction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "       backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mc = MotionCorrect(fname_red, dview=dview, **opts.get_group('motion'))\n",
    "mc.motion_correct(save_movie=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now apply shifts to functional channel\n",
    "We can use the motion correction object, which has stored the shifts from the structural channel, and apply its method `apply_shifts_movie` to the functional channel. Moreover, we can automatically store the result in memory-mapped format. In this case we have a single file, so we can save it directly in order 'C'. If you have multiple files, it is recommended to save them in order 'F' and then reload them and save in order 'C' for efficient downstream processing. See the demo_pipeline notebook for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fname_green = [download_demo('gmc_960_30mw_00001_green.tif')]\n",
    "mmap_file = mc.apply_shifts_movie(fname_green, save_memmap=True, order='C')"
   ]
  },
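  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a toy illustration of why shifts estimated on one channel can be reused on the other, the sketch below estimates an integer rigid shift on a synthetic structural frame and applies it unchanged to a synthetic functional frame that moved identically. It uses plain NumPy and is not CaImAn's implementation. The helper `estimate_shift` and all variable names are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def estimate_shift(frame, template):\n",
    "    # circular cross-correlation via FFT; its peak marks the integer shift\n",
    "    # that has to be applied to `frame` to align it with `template`\n",
    "    xcorr = np.fft.ifft2(np.fft.fft2(template) * np.conj(np.fft.fft2(frame))).real\n",
    "    peak = np.unravel_index(np.argmax(xcorr), xcorr.shape)\n",
    "    # peaks in the upper half of an axis correspond to negative shifts\n",
    "    return tuple(int(p) if p <= s // 2 else int(p) - s for p, s in zip(peak, xcorr.shape))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "red = rng.random((32, 32))        # toy structural frame, used as template\n",
    "green = rng.random((32, 32))      # toy functional frame\n",
    "motion = (3, -5)                  # both channels move together\n",
    "red_moved = np.roll(red, motion, axis=(0, 1))\n",
    "green_moved = np.roll(green, motion, axis=(0, 1))\n",
    "\n",
    "shift = estimate_shift(red_moved, red)                      # estimated on red only\n",
    "green_corrected = np.roll(green_moved, shift, axis=(0, 1))  # reused on green\n",
    "```\n",
    "\n",
    "The functional frame is never used to estimate the shift, so this only works if both channels were acquired simultaneously and moved together."
   ]
  },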
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segmentation and Seeded CNMF\n",
    "\n",
    "We can use the registered structural channel to seed the CNMF algorithm. For example the mean image of the structural channel already shows a lot of neurons:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "R = cm.load(mc.mmap_file)\n",
    "mR = R.mean(0)\n",
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(mR)\n",
    "plt.title('Mean Image for Structural Channel');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Segmentation of the mean image\n",
    "CaImAn has an OpenCV-based function for segmenting images like this one. You specify the parameter `gSig`, which approximates the radius of the average neuron. The output of this function is a binary matrix with dimensions (number of pixels) x (number of components) that can then be used to seed the CNMF algorithm. Each column of the matrix represents the mask of an individual component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Ain = cm.base.rois.extract_binary_masks_from_structural_channel(mR, gSig=7, expand_method='dilation')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_plot_contour(mR, Ain.astype('float32'), mR.shape[0], mR.shape[1], thr=0.99)"
   ]
  },
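  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the pixels x components layout concrete, here is a small sketch with two hand-made toy masks. It assumes the column-major (`'F'`) flattening that the mask-loading code further down in this notebook uses. The array names are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "dims = (4, 5)                                # toy field of view (rows, columns)\n",
    "masks = np.zeros((2,) + dims, dtype=bool)\n",
    "masks[0, 0:2, 0:2] = True                    # first toy ROI\n",
    "masks[1, 2:4, 3:5] = True                    # second toy ROI\n",
    "\n",
    "# one column per component, pixels flattened in column-major ('F') order\n",
    "A_toy = np.stack([m.flatten('F') for m in masks], axis=1)\n",
    "\n",
    "# a column can be turned back into an image with the same ordering\n",
    "roi0 = A_toy[:, 0].reshape(dims, order='F')\n",
    "```\n",
    "\n",
    "Here `A_toy` has one row per pixel and one column per ROI, and reshaping a column with the same ordering recovers the original 2D mask."
   ]
  },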
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now run CaImAn batch (CNMF) seeded with the set of the binary masks\n",
    "\n",
    "To use seeded initialization we pass the set of binary masks in the variable `Ain` when constructing the CNMF object.\n",
    "\n",
    "The two main parameter changes for the seeded compared to the standard CNMF run are `rf = None` (since the component detection is substituted by the mask and is not run in patches anymore) and `only_init = False` (since the initialization is already being done and deconvolution is necessary). The other parameters can be carried over from the standard pipeline.\n",
    "\n",
    "Additionally, it is very important that the expected half size of neurons in pixels, `gSig`, is close to the actual data. ROI detection and separation of the mask as well as CNN evaluation are greatly influenced by `gSig`.\n",
    "\n",
    "For a general explanation of the parameters see the [general](./demo_pipeline.ipynb) notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# dataset dependent parameters\n",
    "rf = None                   # half-size of the patches in pixels. Should be `None` when seeded CNMF is used.\n",
    "only_init = False           # has to be `False` when seeded CNMF is used\n",
    "gSig = (7, 7)               # expected half size of neurons in pixels, very important for proper component detection\n",
    "\n",
    "# params object\n",
    "opts_dict = {'fnames': fname_green,\n",
    "            'decay_time': 0.4,\n",
    "            'p': 1,\n",
    "            'nb': 2,\n",
    "            'rf': rf,\n",
    "            'only_init': only_init,\n",
    "            'gSig': gSig,\n",
    "            'ssub': 1,\n",
    "            'tsub': 1,\n",
    "            'merge_thr': 0.85}\n",
    "\n",
    "opts.change_params(opts_dict);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the memory mapped file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Yr, dims, T = cm.load_memmap(mmap_file)\n",
    "images = np.reshape(Yr.T, [T] + list(dims), order='F') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Restart the cluster to clean up some memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%% restart cluster to clean up memory\n",
    "cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Now construct the object and fit\n",
    "Note how the matrix of masks is passed in the object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_seeded = cnmf.CNMF(n_processes, params=opts, dview=dview, Ain=Ain)\n",
    "cnm_seeded.fit(images);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot contours against the correlation image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CI = cm.local_correlations(images.transpose(1,2,0))     #  correlation image\n",
    "CI[np.isnan(CI)] = 0\n",
    "cnm_seeded.estimates.plot_contours_nb(img=CI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the contour plot above we see that neurons with active spots in the correlation image have been selected by the algorithm. However, there are also many other components with no clear mark in the correlation image. These could correspond to neurons that are inactive during the experiment or that simply do not show up in the correlation image. Such components tend to have non-spiking traces, as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_seeded.estimates.hv_view_components(img=CI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Component filtering\n",
    "We can apply our quality tests to filter out these components. Components are evaluated by the same criteria as in the standard pipeline:\n",
    "\n",
    "- the shape of each component must be correlated with the data at the corresponding location within the FOV\n",
    "- a minimum peak SNR is required over the length of a transient\n",
    "- each shape passes a CNN based classifier\n",
    "\n",
    "For the seeded CNMF, this is necessary to eliminate falsely marked ROIs, especially when ROI selection is performed on structural or mean-intensity images.\n",
    "\n",
    "Experience showed that the CNN classifier might not perform well on manually selected spatial components. If a lot of high-SNR components are rejected, try running the estimation without the CNN (`'use_cnn':False`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters for component evaluation\n",
    "min_SNR = 1.5               # signal to noise ratio for accepting a component\n",
    "rval_thr = 0.8              # space correlation threshold for accepting a component\n",
    "min_cnn_thr = 0.99          # threshold for CNN based classifier\n",
    "cnn_lowest = 0.05           # neurons with cnn probability lower than this value are rejected\n",
    "#cnm_seeded.estimates.restore_discarded_components()\n",
    "cnm_seeded.params.set('quality', {'min_SNR': min_SNR,\n",
    "                           'rval_thr': rval_thr,\n",
    "                           'use_cnn': True,\n",
    "                           'min_cnn_thr': min_cnn_thr,\n",
    "                           'cnn_lowest': cnn_lowest})\n",
    "\n",
    "#%% COMPONENT EVALUATION\n",
    "# the components are evaluated in three ways:\n",
    "#   a) the shape of each component must be correlated with the data\n",
    "#   b) a minimum peak SNR is required over the length of a transient\n",
    "#   c) each shape passes a CNN based classifier\n",
    "\n",
    "cnm_seeded.estimates.evaluate_components(images, cnm_seeded.params, dview=dview)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm_seeded.estimates.plot_contours_nb(img=CI, idx=cnm_seeded.estimates.idx_components)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## accepted components\n",
    "cnm_seeded.estimates.hv_view_components(img=CI, idx=cnm_seeded.estimates.idx_components)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## rejected components\n",
    "cnm_seeded.estimates.hv_view_components(img=CI, idx=cnm_seeded.estimates.idx_components_bad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparison with (unseeded) CaImAn batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For comparison, we can also run the unseeded CNMF. For this we set the patch half-size rf together with K and merge_parallel\n",
    "opts.change_params({'rf': 48, 'K': 12, 'merge_parallel': True})\n",
    "cnm = cnmf.CNMF(n_processes, params= opts, dview=dview)\n",
    "cnm.fit(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cnm.estimates.restore_discarded_components()\n",
    "cnm.params.set('quality', {'min_SNR': min_SNR,\n",
    "                           'rval_thr': rval_thr,\n",
    "                           'use_cnn': True,\n",
    "                           'min_cnn_thr': min_cnn_thr,\n",
    "                           'cnn_lowest': cnn_lowest})\n",
    "\n",
    "#%% COMPONENT EVALUATION\n",
    "# the components are evaluated in three ways:\n",
    "#   a) the shape of each component must be correlated with the data\n",
    "#   b) a minimum peak SNR is required over the length of a transient\n",
    "#   c) each shape passes a CNN based classifier\n",
    "\n",
    "cnm.estimates.evaluate_components(images, cnm.params, dview=dview)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.plot_contours_nb(img=CI, idx=cnm.estimates.idx_components)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select only high quality components and compare two approaches\n",
    "\n",
    "This updates the `estimates` object by only keeping the components indexed by `idx_components`. Components from `idx_components_bad` are removed. After this step, indexing the `estimates` object as in the cells above is unnecessary and has no effect. Rejected components can be recovered with the `restore_discarded_components()` function, as long as `select_components()` has not been called with `save_discarded_components=False`.\n",
    "\n",
    "The set of accepted components is largely the same in both cases. We can quantify this comparison by using the `register_ROIs` function to register the components against each other. Before doing that, the spatial footprints need also to be thresholded to remove any faint tails."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.select_components(use_object=True)   # select only the accepted components\n",
    "cnm_seeded.estimates.select_components(use_object=True)\n",
    "cnm.estimates.threshold_spatial_components()\n",
    "cnm_seeded.estimates.threshold_spatial_components()\n",
    "res = cm.base.rois.register_ROIs(cnm.estimates.A_thr, cnm_seeded.estimates.A_thr, CI.shape,\n",
    "                                 align_flag=False, plot_results=True)\n",
    "print(res[4])"
   ]
  },
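  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind such a comparison can be sketched on toy data. This is a simplified illustration and not the `register_ROIs` implementation: it computes the intersection over union between every pair of binary footprints and pairs them with SciPy's Hungarian solver. The function `match_rois`, the threshold `min_iou` and the toy arrays are assumptions made for this example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "def match_rois(A1, A2, min_iou=0.5):\n",
    "    # A1, A2: boolean arrays of shape (pixels, components)\n",
    "    A1 = A1.astype(float)\n",
    "    A2 = A2.astype(float)\n",
    "    inter = A1.T @ A2                                   # pairwise overlap in pixels\n",
    "    union = A1.sum(0)[:, None] + A2.sum(0)[None, :] - inter\n",
    "    iou = inter / np.maximum(union, 1)                  # avoid division by zero\n",
    "    rows, cols = linear_sum_assignment(1 - iou)         # minimize total (1 - IoU)\n",
    "    keep = iou[rows, cols] >= min_iou                   # drop weak pairings\n",
    "    return [(int(r), int(c)) for r, c in zip(rows[keep], cols[keep])]\n",
    "\n",
    "A1_toy = np.zeros((10, 2), dtype=bool)\n",
    "A1_toy[0:4, 0] = True\n",
    "A1_toy[5:9, 1] = True\n",
    "A2_toy = np.zeros((10, 3), dtype=bool)\n",
    "A2_toy[5:10, 0] = True      # overlaps the second ROI of A1_toy\n",
    "A2_toy[0:3, 1] = True       # overlaps the first ROI of A1_toy\n",
    "A2_toy[4, 2] = True         # has no counterpart in A1_toy\n",
    "\n",
    "matches = match_rois(A1_toy, A2_toy)\n",
    "```\n",
    "\n",
    "Each entry of `matches` is a pair (index in the first set, index in the second set). Components without a sufficiently overlapping partner stay unmatched."
   ]
  },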
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seeding CaImAn online algorithm\n",
    "The same procedure can also be followed for the CaImAn online algorithm. Seeding can be done by passing the binary mask in the `estimates.A` field of the `OnACID` object and then calling the `fit_online` method:\n",
    "```\n",
    "cnm_on = cnmf.online_cnmf.OnACID(params=opts)\n",
    "cnm_on.estimates.A = Ain\n",
    "cnm_on.fit_online()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some info on creating your own binary seeding mask\n",
    "\n",
    "To run a seeded CNMF, we have to provide the algorithm with directions on where it has to look for neurons. This can be done by creating a binary mask with the same dimensions as the movie, which marks possible neuron locations with `True` or `1`.\n",
    "\n",
    "To manually determine neuron locations, a template on which to find neurons has to be created. This template can be constructed with two different methods:\n",
    "1. **Structural channel and average intensity:** If your recording incorporates a calcium-independent structural channel, it can be used to extract locations of neurons by averaging the intensity of each pixel over time. This method can also be applied to the calcium-dependent channel itself. The averaging greatly reduces noise, but any temporal component is eliminated, and it is impossible to tell whether a neuron was active or inactive during the recording. Thus, many neurons selected through this technique will be false positives, which should be filtered out during the component evaluation.\n",
    "2. **Local correlations:** An arguably more accurate template is the local correlation image of the calcium-dependent signal. Here, each pixel's value is not determined by its intensity, but by the **correlation** of its intensity with that of its neighbors. Thus, groups of pixels that change intensity together will be brighter. This method incorporates the temporal component of the signal and accentuates firing structures like neurons and dendrites. Features visible in the local correlation image are likely functional units (such as neurons), which is what we are ultimately interested in. The number of false positives should be lower than in method 1, as structural features visible in mean-intensity images are greatly reduced. Additionally, it reduces the donut shape of some somata, making neuron detection easier.\n",
    "\n",
    "A binary seeding mask from this template can be created automatically or manually:\n",
    "1. **Automatic:** The CaImAn function `extract_binary_masks_from_structural_channel()` does basically what it says. It extracts binary masks from a movie or an image with the Adaptive Gaussian Thresholding algorithm provided by the OpenCV library. If the function is provided a movie, it applies the thresholding on the average intensity image of this movie.\n",
    "2. **Manual:** Raw movies, average intensity images or local correlation images can be loaded into Fiji to manually select ROIs. Average intensity images can be created from a raw movie through `Image → Stacks → Z Project...`. Manually creating a binary mask in Fiji is described in the next cell. \n",
    "\n",
    "This notebook demonstrates the seeding with a mask manually created from a local correlations image. The automatic method is displayed, but commented out to avoid confusion. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import ndimage as ndi\n",
    "from skimage.io import imsave, imread\n",
    "from skimage.morphology import dilation\n",
    "from skimage.segmentation import find_boundaries, watershed\n",
    "\n",
    "# Create local correlations image and save it\n",
    "# swap_dim should be set to True if the movie is in MATLAB format (time is last instead of first axis)\n",
    "Cn = cm.local_correlations(images, swap_dim=False)\n",
    "Cn[np.isnan(Cn)] = 0\n",
    "save_path = os.path.splitext(fname_green[0])[0] + '_local_corr.png'\n",
    "imsave(save_path, Cn)\n",
    "\n",
    "'''\n",
    "# Create image from structural or functional channel and check automatically extracted contours\n",
    "fname = \"filepath\"\n",
    "Ain, mR = cm.base.rois.extract_binary_masks_from_structural_channel(\n",
    "    cm.load(fname), expand_method='dilation', selem=np.ones((1, 1)))\n",
    "plt.figure()\n",
    "cm.utils.visualization.plot_contours(Ain.astype('float32'), mR, thr=0.99, display_numbers=False)\n",
    "plt.title('Contour plots of detected ROIs in the structural channel')\n",
    "'''"
   ]
  },
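  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a local correlation image concrete, here is a simplified NumPy sketch. It averages the temporal correlation of each pixel with its (up to) four direct neighbors and is not CaImAn's `local_correlations` function. The helper `local_corr_4` and the synthetic movie are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def local_corr_4(movie):\n",
    "    # movie: array of shape (T, rows, cols); z-score every pixel trace over time\n",
    "    Y = movie - movie.mean(axis=0)\n",
    "    sd = Y.std(axis=0)\n",
    "    sd[sd == 0] = np.inf              # constant pixels get zero correlation\n",
    "    Y = Y / sd\n",
    "    corr_sum = np.zeros(movie.shape[1:])\n",
    "    n_neigh = np.zeros(movie.shape[1:])\n",
    "    # correlation between vertically adjacent pixels\n",
    "    c = (Y[:, :-1, :] * Y[:, 1:, :]).mean(axis=0)\n",
    "    corr_sum[:-1, :] += c\n",
    "    corr_sum[1:, :] += c\n",
    "    n_neigh[:-1, :] += 1\n",
    "    n_neigh[1:, :] += 1\n",
    "    # correlation between horizontally adjacent pixels\n",
    "    c = (Y[:, :, :-1] * Y[:, :, 1:]).mean(axis=0)\n",
    "    corr_sum[:, :-1] += c\n",
    "    corr_sum[:, 1:] += c\n",
    "    n_neigh[:, :-1] += 1\n",
    "    n_neigh[:, 1:] += 1\n",
    "    return corr_sum / n_neigh\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "toy_movie = rng.standard_normal((500, 10, 10))                         # pure noise\n",
    "toy_movie[:, 3:6, 3:6] += 3 * rng.standard_normal(500)[:, None, None]  # shared signal\n",
    "Cn_toy = local_corr_4(toy_movie)\n",
    "```\n",
    "\n",
    "Pixels inside the 3 x 3 block share a common signal and therefore stand out in `Cn_toy`, while pure-noise pixels stay close to zero even though their mean intensity is the same."
   ]
  },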
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our local correlation image, we can manually select ROIs with Fiji. For this, open the image in Fiji/ImageJ and draw the outline of a potential neuron with the `Freehand selections` tool. Then, add the current selection to the ROI manager with `t` (Mac and Windows). Please note that for proper ROI detection, make sure that features cut off at the edge have a ROI that touches the edge as well.\n",
    "\n",
    "Neurons in the imaged tissue can lie in different depths, thus their ROIs and calcium signals can overlap.\n",
    "The best solution to account for this is to create one mask for each component. ROIs can be saved as binary masks through `Edit → Selection → Create mask`. Fiji's macro function can help to automatically save individual ROIs as binary images. This preserves original neuron shapes and potential overlaps. CaImAn recognizes pixels that are shared by multiple components and takes this into account when extracting the calcium signal from these regions. It is thus recommended to retain overlapping ROIs when seeding the CNMF.\n",
    "\n",
    "Alternatively, individual selections can be combined in Fiji and exported together. Then, overlapping ROIs will be separated afterwards with a watershed algorithm (see below)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extracting spatial components from binary masks\n",
    "The CNMF object stores spatial components in a sparse column format (`csc_matrix`) with n rows (number of pixels of the flattened image) and m columns (number of components). Thus, single ROIs have to be extracted and transformed into individual binary masks."
   ]
  },
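  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loops in the following sections collect the masks in a dense boolean NumPy array. If a SciPy sparse matrix is preferred, the conversion could look like the sketch below. The toy array is made up, and whether a dense or sparse `Ain` is more convenient for your CaImAn version is not asserted here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy import sparse\n",
    "\n",
    "n_pixels, n_components = 12, 2\n",
    "A_dense = np.zeros((n_pixels, n_components), dtype=bool)\n",
    "A_dense[[0, 1, 3], 0] = True          # pixels of the first toy component\n",
    "A_dense[[7, 8], 1] = True             # pixels of the second toy component\n",
    "\n",
    "# compressed sparse column format: efficient access to individual components\n",
    "A_sparse = sparse.csc_matrix(A_dense)\n",
    "pixels_of_second = A_sparse[:, 1].nonzero()[0]\n",
    "```\n",
    "\n",
    "Only the nonzero entries are stored, which matters for real data where each component covers a tiny fraction of the field of view."
   ]
  },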
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading separate masks for each component (recommended)\n",
    "\n",
    "If each component ROI has been saved in its own mask, they can be loaded and transformed into the `csc_matrix` format individually. This method is recommended, as it retains the original shapes of neurons, and the algorithm takes pixels shared across multiple components into account during signal extraction.\n",
    "\n",
    "In this demo it is assumed that individual binary .png masks are stored in one directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.morphology import closing\n",
    "\n",
    "# Choose expand method and morphological element of expansion. Dilation of masks provides more room for error\n",
    "# and ensures that the complete neuron is included\n",
    "expand_method = 'dilation'\n",
    "selem = np.ones((3, 3))\n",
    "\n",
    "# Set the path of the individual masks. Adapt file extension if masks are not saved as .png\n",
    "dir_path = 'path of directory with masks'\n",
    "file_list = glob(dir_path+'/*.png')\n",
    "\n",
    "for i in range(len(file_list)):\n",
    "    # temporarily save the current mask as a boolean array\n",
    "    temp = np.asarray(imread(file_list[i]), dtype=bool)\n",
    "    \n",
    "    # the mask matrix (pixels x components) has to be initialized before adding the first mask\n",
    "    if i == 0:\n",
    "        A = np.zeros((np.prod(temp.shape), len(file_list)), dtype=bool)\n",
    "        \n",
    "    # apply dilation or closing to the mask (optional)\n",
    "    # the structuring element is passed positionally because its keyword name differs between scikit-image versions\n",
    "    if expand_method == 'dilation':\n",
    "        temp = dilation(temp, selem)\n",
    "    elif expand_method == 'closing':\n",
    "        temp = closing(temp, selem)\n",
    "\n",
    "    # flatten the mask to a 1D array and add it as a column of the mask matrix\n",
    "    A[:, i] = temp.flatten('F')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading one mask that includes all components (not recommended)\n",
    "If the masks for all components have been saved in one file, some more processing is required to separate ROIs for each component. Since now overlapping components result in fused ROIs, they have to be separated with a watershed algorithm before being passed into the csc_matrix. This is not recommended because information of which pixels are shared by multiple components is lost, which leads to contamination of the signal from one component to the other at shared pixels. Furthermore, the additional processing step decreases the precision due to imperfect ROI separation.\n",
    "\n",
    "If you have separate masks and do not need this step, skip the next cells until `Perform seeded CNMF`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Identify individual features\n",
    "To run the watershed algorithm, we have to label each individual component. To separate overlapping features, a distance matrix can be calculated, where pixels get a value representing their distance to the nearest edge. This creates local maxima in the center of features which are not overlapping with nearby features. To account for irregularly shaped cells that would create more than one local maximum, the expected cell diameter `gSig_seed` is used to threshold local maxima, merging local maxima that are closer together than the cell radius."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = [16, 10]\n",
    "\n",
    "# load binary mask\n",
    "maskfile = os.path.join(cm.paths.caiman_datadir(), 'example_movies', 'avg_mask_fixed.png')\n",
    "mask = np.asarray(imread(maskfile), dtype=bool)\n",
    "\n",
    "# calculate distances from nearest edge\n",
    "distances = ndi.distance_transform_edt(mask)\n",
    "\n",
    "# apply threshold of expected cell diameter gSig_seed to get one maximum per cell\n",
    "gSig_seed = 20\n",
    "local_max = distances >= gSig_seed / 2   # True for pixels at least half a cell diameter away from the nearest edge\n",
    "\n",
    "# visualize original mask, distance matrix and thresholded local maxima\n",
    "fig, ax = plt.subplots(1,3)\n",
    "plt.sca(ax[0])\n",
    "plt.imshow(mask)\n",
    "plt.gca().set_title('Original mask')\n",
    "plt.sca(ax[1])\n",
    "plt.imshow(distances)\n",
    "plt.gca().set_title('Distance matrix')\n",
    "plt.sca(ax[2])\n",
    "plt.imshow(local_max)\n",
    "plt.gca().set_title('Local maxima')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Separate components with watershed algorithm\n",
    "We now have the location of separated features, but we still have to apply this separation to the original mask to preserve its shape as much as possible. For this we label the local maxima individually to prepare them for the watershed algorithm. Some irregularly shaped cells will have multiple maxima, even after the thresholding done in the previous step. We should thus remove very small features. These cleaned-up labels can then be used to apply watershed to the original image. Now the watershed algorithm knows where the individual components are and calculates smooth boundaries between them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate labels of isolated maxima\n",
    "markers = ndi.label(local_max)[0]\n",
    "\n",
    "# remove very small features (avoids irregular cells being counted multiple times)\n",
    "sizes = np.bincount(markers.ravel())      # get list of number of pixels of each label\n",
    "mask_sizes = sizes > 5                    # remove labels with very low pixel counts\n",
    "mask_sizes[0] = 0                         # remove count of background label\n",
    "local_max_cleaned = mask_sizes[markers]   # apply mask to binary image to only keep large components\n",
    "\n",
    "# update labels with cleaned-up features\n",
    "markers_cleaned = ndi.label(local_max_cleaned)[0]\n",
    "\n",
    "# apply watershed to the original binary mask using the cleaned labels to create separated ROIs\n",
    "labels = watershed(-distances, markers_cleaned, mask=mask)\n",
    "\n",
    "# visualize separated labels after watershed (each feature has a slightly different color)\n",
    "plt.imshow(labels)\n",
    "plt.title('Labels of features separated by watershed algorithm')\n",
    "plt.show()"
   ]
  },
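  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same sequence of steps can be tried on a synthetic example where the expected outcome is obvious: two overlapping discs that form a single fused blob. This is only a sketch on made-up data. The disc geometry and the hand-picked threshold of 10 pixels are assumptions for this toy case.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy import ndimage as ndi\n",
    "from skimage.segmentation import watershed\n",
    "\n",
    "# two overlapping discs (radius 12, centers 18 pixels apart) fused into one blob\n",
    "yy, xx = np.mgrid[:40, :60]\n",
    "disc1 = (yy - 20) ** 2 + (xx - 20) ** 2 <= 12 ** 2\n",
    "disc2 = (yy - 20) ** 2 + (xx - 38) ** 2 <= 12 ** 2\n",
    "toy_mask = disc1 | disc2\n",
    "\n",
    "# distance of every mask pixel to the nearest background pixel\n",
    "toy_dist = ndi.distance_transform_edt(toy_mask)\n",
    "\n",
    "# pixels deep inside a disc become seeds; the narrow neck stays below the threshold\n",
    "toy_seeds = ndi.label(toy_dist >= 10)[0]\n",
    "\n",
    "# flood the mask from the seeds to split the blob into separate ROIs\n",
    "toy_labels = watershed(-toy_dist, toy_seeds, mask=toy_mask)\n",
    "\n",
    "n_blobs = ndi.label(toy_mask)[1]      # connected blobs before separation\n",
    "n_rois = int(toy_labels.max())        # ROIs after separation\n",
    "```\n",
    "\n",
    "Before the watershed the mask is one connected blob. Afterwards each disc carries its own label, with the boundary running through the neck between them."
   ]
  },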
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bring labels into sparse column format that CNMF can use\n",
    "Now we have each component marked with a different label, even if their ROIs are overlapping or merged. We can use this matrix to bring it into the sparse column format that the CNMF uses to store spatial components. Each of the labelled features will be extracted separately and used as the spatial component for this neuron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize sparse column matrix\n",
    "num_features = np.max(labels)\n",
    "A = np.zeros((labels.size, num_features), dtype=bool)\n",
    "\n",
    "# save each component individually into the sparse column matrix\n",
    "for i in range(num_features):\n",
    "    temp = (labels == i + 1)                     # each feature is saved as a single component in its own frame\n",
    "    temp = dilation(temp, np.ones((3, 3)))       # dilate spatial component to increase error margin a bit\n",
    "\n",
    "    # parse the current component 'temp' into the sparse column matrix\n",
    "    A[:, i] = temp.flatten('F')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "caiman_pytorch2",
   "language": "python",
   "name": "caiman_pytorch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
